{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e3afcc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "%set_env PYTORCH_ENABLE_MPS_FALLBACK=1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "524620c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp models.nbeatsx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15392f6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "12fa25a4",
   "metadata": {},
   "source": [
    "# NBEATSx"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "1822c9e8",
   "metadata": {},
   "source": [
    "The Neural Basis Expansion Analysis (`NBEATS`) is an `MLP`-based deep neural architecture with backward and forward residual links. The network has two variants: (1) in its interpretable configuration, `NBEATS` sequentially projects the signal into polynomials and harmonic basis to learn trend and seasonality components; (2) in its generic configuration, it substitutes the polynomial and harmonic basis for identity basis and larger network's depth. The Neural Basis Expansion Analysis with Exogenous (`NBEATSx`), incorporates projections to exogenous temporal variables available at the time of the prediction.<br><br> This method proved state-of-the-art performance on the M3, M4, and Tourism Competition datasets, improving accuracy by 3% over the `ESRNN` M4 competition winner. For Electricity Price Forecasting tasks `NBEATSx` model improved accuracy by 20% and 5% over `ESRNN` and `NBEATS`, and 5% on task-specialized architectures.<br><br>**References**<br>-[Boris N. Oreshkin, Dmitri Carpov, Nicolas Chapados, Yoshua Bengio (2019). \"N-BEATS: Neural basis expansion analysis for interpretable time series forecasting\".](https://arxiv.org/abs/1905.10437)<br>-[Kin G. Olivares, Cristian Challu, Grzegorz Marcjasz, Rafał Weron, Artur Dubrawski (2021). \"Neural basis expansion analysis with exogenous variables: Forecasting electricity prices with NBEATSx\".](https://arxiv.org/abs/2104.05522)<br>"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "bddd17a6",
   "metadata": {},
   "source": [
    "![Figure 1. Neural Basis Expansion Analysis with Exogenous Variables.](imgs_models/nbeatsx.png)"
   ]
  },
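  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "3b7e2f10",
   "metadata": {},
   "source": [
    "As a minimal illustrative sketch (not part of the exported module), the function below expresses the doubly residual principle depicted above in plain Python: each block explains part of the input window through its backcast and contributes an additive partial forecast. Here `blocks` stands in for any sequence of callables returning a `(backcast, forecast)` pair, like the `NBEATSBlock` modules defined below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c8f3a21",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Schematic sketch of the doubly residual stacking (illustration only).\n",
    "def doubly_residual_forecast(y_window, blocks):\n",
    "    residual = y_window  # signal left to be explained\n",
    "    forecast = 0.0  # accumulated prediction\n",
    "    for block in blocks:\n",
    "        backcast, partial_forecast = block(residual)\n",
    "        residual = residual - backcast  # backward residual link\n",
    "        forecast = forecast + partial_forecast  # forward residual link\n",
    "    return forecast"
   ]
  },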
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25e1718a",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "import logging\n",
    "import warnings\n",
    "\n",
    "from fastcore.test import test_eq, test_fail\n",
    "from nbdev.showdoc import show_doc\n",
    "from neuralforecast.utils import generate_series\n",
    "from neuralforecast.common._model_checks import check_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "262e6ab4",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from typing import Tuple, Optional\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "from neuralforecast.losses.pytorch import MAE\n",
    "from neuralforecast.common._base_model import BaseModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "621dd3a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR)\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b7a9fae-2c29-47e2-874e-ca1f20bf7040",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "class IdentityBasis(nn.Module):\n",
    "    def __init__(self, backcast_size: int, forecast_size: int, out_features: int = 1):\n",
    "        super().__init__()\n",
    "        self.out_features = out_features\n",
    "        self.forecast_size = forecast_size\n",
    "        self.backcast_size = backcast_size\n",
    "\n",
    "    def forward(self, theta: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n",
    "        backcast = theta[:, : self.backcast_size]\n",
    "        forecast = theta[:, self.backcast_size :]\n",
    "        forecast = forecast.reshape(len(forecast), -1, self.out_features)\n",
    "        return backcast, forecast\n",
    "\n",
    "\n",
    "class TrendBasis(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        degree_of_polynomial: int,\n",
    "        backcast_size: int,\n",
    "        forecast_size: int,\n",
    "        out_features: int = 1,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.out_features = out_features\n",
    "        polynomial_size = degree_of_polynomial + 1\n",
    "        self.backcast_basis = nn.Parameter(\n",
    "            torch.tensor(\n",
    "                np.concatenate(\n",
    "                    [\n",
    "                        np.power(\n",
    "                            np.arange(backcast_size, dtype=float) / backcast_size, i\n",
    "                        )[None, :]\n",
    "                        for i in range(polynomial_size)\n",
    "                    ]\n",
    "                ),\n",
    "                dtype=torch.float32,\n",
    "            ),\n",
    "            requires_grad=False,\n",
    "        )\n",
    "        self.forecast_basis = nn.Parameter(\n",
    "            torch.tensor(\n",
    "                np.concatenate(\n",
    "                    [\n",
    "                        np.power(\n",
    "                            np.arange(forecast_size, dtype=float) / forecast_size, i\n",
    "                        )[None, :]\n",
    "                        for i in range(polynomial_size)\n",
    "                    ]\n",
    "                ),\n",
    "                dtype=torch.float32,\n",
    "            ),\n",
    "            requires_grad=False,\n",
    "        )\n",
    "\n",
    "    def forward(self, theta: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n",
    "        polynomial_size = self.forecast_basis.shape[0]  # [polynomial_size, L+H]\n",
    "        backcast_theta = theta[:, :polynomial_size]\n",
    "        forecast_theta = theta[:, polynomial_size:]\n",
    "        forecast_theta = forecast_theta.reshape(\n",
    "            len(forecast_theta), polynomial_size, -1\n",
    "        )\n",
    "        backcast = torch.einsum(\"bp,pt->bt\", backcast_theta, self.backcast_basis)\n",
    "        forecast = torch.einsum(\"bpq,pt->btq\", forecast_theta, self.forecast_basis)\n",
    "        return backcast, forecast\n",
    "\n",
    "class ExogenousBasis(nn.Module):\n",
    "    # Reference: https://github.com/cchallu/nbeatsx\n",
    "    def __init__(self, forecast_size: int):\n",
    "        super().__init__()\n",
    "        self.forecast_size = forecast_size\n",
    "\n",
    "    def forward(self, theta: torch.Tensor, futr_exog: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n",
    "        backcast_basis = futr_exog[:, :-self.forecast_size, :].permute(0, 2, 1)\n",
    "        forecast_basis = futr_exog[:, -self.forecast_size:, :].permute(0, 2, 1)\n",
    "        cut_point = forecast_basis.shape[1]\n",
    "        backcast_theta=theta[:, cut_point:]\n",
    "        forecast_theta=theta[:, :cut_point].reshape(\n",
    "            len(theta), cut_point, -1\n",
    "        )\n",
    "     \n",
    "        backcast = torch.einsum('bp,bpt->bt', backcast_theta, backcast_basis)\n",
    "        forecast = torch.einsum('bpq,bpt->btq', forecast_theta, forecast_basis)\n",
    "        \n",
    "        return backcast, forecast\n",
    "\n",
    "class SeasonalityBasis(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        harmonics: int,\n",
    "        backcast_size: int,\n",
    "        forecast_size: int,\n",
    "        out_features: int = 1,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.out_features = out_features\n",
    "        frequency = np.append(\n",
    "            np.zeros(1, dtype=float),\n",
    "            np.arange(harmonics, harmonics / 2 * forecast_size, dtype=float)\n",
    "            / harmonics,\n",
    "        )[None, :]\n",
    "        backcast_grid = (\n",
    "            -2\n",
    "            * np.pi\n",
    "            * (np.arange(backcast_size, dtype=float)[:, None] / forecast_size)\n",
    "            * frequency\n",
    "        )\n",
    "        forecast_grid = (\n",
    "            2\n",
    "            * np.pi\n",
    "            * (np.arange(forecast_size, dtype=float)[:, None] / forecast_size)\n",
    "            * frequency\n",
    "        )\n",
    "\n",
    "        backcast_cos_template = torch.tensor(\n",
    "            np.transpose(np.cos(backcast_grid)), dtype=torch.float32\n",
    "        )\n",
    "        backcast_sin_template = torch.tensor(\n",
    "            np.transpose(np.sin(backcast_grid)), dtype=torch.float32\n",
    "        )\n",
    "        backcast_template = torch.cat(\n",
    "            [backcast_cos_template, backcast_sin_template], dim=0\n",
    "        )\n",
    "\n",
    "        forecast_cos_template = torch.tensor(\n",
    "            np.transpose(np.cos(forecast_grid)), dtype=torch.float32\n",
    "        )\n",
    "        forecast_sin_template = torch.tensor(\n",
    "            np.transpose(np.sin(forecast_grid)), dtype=torch.float32\n",
    "        )\n",
    "        forecast_template = torch.cat(\n",
    "            [forecast_cos_template, forecast_sin_template], dim=0\n",
    "        )\n",
    "\n",
    "        self.backcast_basis = nn.Parameter(backcast_template, requires_grad=False)\n",
    "        self.forecast_basis = nn.Parameter(forecast_template, requires_grad=False)\n",
    "\n",
    "    def forward(self, theta: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n",
    "        harmonic_size = self.forecast_basis.shape[0]  # [harmonic_size, L+H]\n",
    "        backcast_theta = theta[:, :harmonic_size]\n",
    "        forecast_theta = theta[:, harmonic_size:]\n",
    "        forecast_theta = forecast_theta.reshape(len(forecast_theta), harmonic_size, -1)\n",
    "        backcast = torch.einsum(\"bp,pt->bt\", backcast_theta, self.backcast_basis)\n",
    "        forecast = torch.einsum(\"bpq,pt->btq\", forecast_theta, self.forecast_basis)\n",
    "        return backcast, forecast"
   ]
  },
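  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d9a4b32",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# Illustrative sanity sketch (arbitrary sizes): each basis maps a coefficient\n",
    "# vector theta into a (backcast, forecast) pair. The theta widths below follow\n",
    "# the n_theta conventions used in `create_stack` further down.\n",
    "L_, H_, B_ = 8, 4, 2\n",
    "\n",
    "# IdentityBasis: theta = [backcast (L), forecast (H * out_features)]\n",
    "basis = IdentityBasis(backcast_size=L_, forecast_size=H_, out_features=1)\n",
    "backcast, forecast = basis(torch.randn(B_, L_ + H_))\n",
    "test_eq(backcast.shape, (B_, L_))\n",
    "test_eq(forecast.shape, (B_, H_, 1))\n",
    "\n",
    "# TrendBasis: theta = (out_features + 1) * (degree + 1) polynomial coefficients\n",
    "basis = TrendBasis(degree_of_polynomial=2, backcast_size=L_, forecast_size=H_)\n",
    "backcast, forecast = basis(torch.randn(B_, 2 * 3))\n",
    "test_eq(backcast.shape, (B_, L_))\n",
    "test_eq(forecast.shape, (B_, H_, 1))\n",
    "\n",
    "# SeasonalityBasis: theta = (out_features + 1) * harmonic_size Fourier coefficients\n",
    "basis = SeasonalityBasis(harmonics=2, backcast_size=L_, forecast_size=H_)\n",
    "backcast, forecast = basis(torch.randn(B_, 2 * basis.forecast_basis.shape[0]))\n",
    "test_eq(backcast.shape, (B_, L_))\n",
    "test_eq(forecast.shape, (B_, H_, 1))"
   ]
  },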
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17382790-7d84-4a89-959b-5676afa46392",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "ACTIVATIONS = [\"ReLU\", \"Softplus\", \"Tanh\", \"SELU\", \"LeakyReLU\", \"PReLU\", \"Sigmoid\"]\n",
    "\n",
    "\n",
    "class NBEATSBlock(nn.Module):\n",
    "    \"\"\"\n",
    "    N-BEATS block which takes a basis function as an argument.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        input_size: int,\n",
    "        h: int,\n",
    "        futr_input_size: int,\n",
    "        hist_input_size: int,\n",
    "        stat_input_size: int,\n",
    "        n_theta: int,\n",
    "        mlp_units: list,\n",
    "        basis: nn.Module,\n",
    "        dropout_prob: float,\n",
    "        activation: str,\n",
    "    ):\n",
    "        \"\"\" \"\"\"\n",
    "        super().__init__()\n",
    "\n",
    "        self.h = h\n",
    "        self.dropout_prob = dropout_prob\n",
    "        self.input_size = input_size\n",
    "        self.futr_input_size = futr_input_size\n",
    "        self.hist_input_size = hist_input_size\n",
    "        self.stat_input_size = stat_input_size\n",
    "\n",
    "        assert activation in ACTIVATIONS, f\"{activation} is not in {ACTIVATIONS}\"\n",
    "        activ = getattr(nn, activation)()\n",
    "\n",
    "        # Input vector for the block is\n",
    "        # y_lags (input_size) + historical exogenous (hist_input_size*input_size) +\n",
    "        # future exogenous (futr_input_size*input_size) + static exogenous (stat_input_size)\n",
    "        # [ Y_[t-L:t], X_[t-L:t], F_[t-L:t+H], S ]\n",
    "        input_size = (\n",
    "            input_size\n",
    "            + hist_input_size * input_size\n",
    "            + futr_input_size * (input_size + h)\n",
    "            + stat_input_size\n",
    "        )\n",
    "\n",
    "        hidden_layers = [\n",
    "            nn.Linear(in_features=input_size, out_features=mlp_units[0][0])\n",
    "        ]\n",
    "        for layer in mlp_units:\n",
    "            hidden_layers.append(nn.Linear(in_features=layer[0], out_features=layer[1]))\n",
    "            hidden_layers.append(activ)\n",
    "\n",
    "            if self.dropout_prob > 0:\n",
    "                hidden_layers.append(nn.Dropout(p=self.dropout_prob))\n",
    "\n",
    "        output_layer = [nn.Linear(in_features=mlp_units[-1][1], out_features=n_theta)]\n",
    "        layers = hidden_layers + output_layer\n",
    "        self.layers = nn.Sequential(*layers)\n",
    "        self.basis = basis\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        insample_y: torch.Tensor,\n",
    "        futr_exog: torch.Tensor,\n",
    "        hist_exog: torch.Tensor,\n",
    "        stat_exog: torch.Tensor,\n",
    "    ) -> Tuple[torch.Tensor, torch.Tensor]:\n",
    "        # Flatten MLP inputs [B, L+H, C] -> [B, (L+H)*C]\n",
    "        # Contatenate [ Y_t, | X_{t-L},..., X_{t} | F_{t-L},..., F_{t+H} | S ]\n",
    "        batch_size = len(insample_y)\n",
    "        if self.hist_input_size > 0:\n",
    "            insample_y = torch.cat(\n",
    "                (insample_y, hist_exog.reshape(batch_size, -1)), dim=1\n",
    "            )\n",
    "\n",
    "        if self.futr_input_size > 0:\n",
    "            insample_y = torch.cat(\n",
    "                (insample_y, futr_exog.reshape(batch_size, -1)), dim=1\n",
    "            )\n",
    "\n",
    "        if self.stat_input_size > 0:\n",
    "            insample_y = torch.cat(\n",
    "                (insample_y, stat_exog.reshape(batch_size, -1)), dim=1\n",
    "            )\n",
    "\n",
    "        # Compute local projection weights and projection\n",
    "        theta = self.layers(insample_y)\n",
    "\n",
    "        if isinstance(self.basis, ExogenousBasis):\n",
    "            if self.futr_input_size > 0 and self.stat_input_size > 0:                \n",
    "                futr_exog = torch.cat(\n",
    "                    (\n",
    "                        futr_exog,\n",
    "                        stat_exog.unsqueeze(1).expand(-1, futr_exog.shape[1], -1)\n",
    "                    ),\n",
    "                    dim=2\n",
    "                )\n",
    "            elif self.futr_input_size >0:\n",
    "                futr_exog = futr_exog\n",
    "            elif self.stat_input_size >0:\n",
    "                futr_exog = stat_exog.unsqueeze(1).expand(-1, self.input_size+self.h, -1)\n",
    "            else:\n",
    "                raise(ValueError(\"No stats or future exogenous. ExogenousBlock not supported.\"))    \n",
    "            backcast, forecast = self.basis(theta, futr_exog)\n",
    "            return backcast, forecast\n",
    "        else:\n",
    "            backcast, forecast = self.basis(theta)\n",
    "            return backcast, forecast"
   ]
  },
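  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e0b5c43",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# Illustrative sanity sketch (arbitrary sizes): a generic block without exogenous\n",
    "# inputs maps a window of length L into a length-L backcast and an H-step forecast.\n",
    "L_, H_, B_ = 8, 4, 2\n",
    "block = NBEATSBlock(\n",
    "    input_size=L_,\n",
    "    h=H_,\n",
    "    futr_input_size=0,\n",
    "    hist_input_size=0,\n",
    "    stat_input_size=0,\n",
    "    n_theta=L_ + H_,\n",
    "    mlp_units=[[16, 16], [16, 16]],\n",
    "    basis=IdentityBasis(backcast_size=L_, forecast_size=H_, out_features=1),\n",
    "    dropout_prob=0.0,\n",
    "    activation=\"ReLU\",\n",
    ")\n",
    "backcast, forecast = block(\n",
    "    insample_y=torch.randn(B_, L_), futr_exog=None, hist_exog=None, stat_exog=None\n",
    ")\n",
    "test_eq(backcast.shape, (B_, L_))\n",
    "test_eq(forecast.shape, (B_, H_, 1))"
   ]
  },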
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be997aeb-778f-442d-a97a-ff47de2deab6",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class NBEATSx(BaseModel):\n",
    "    \"\"\"NBEATSx\n",
    "\n",
    "    The Neural Basis Expansion Analysis with Exogenous variables (NBEATSx) is a simple\n",
    "    and effective deep learning architecture. It is built with a deep stack of MLPs with\n",
    "    doubly residual connections. The NBEATSx architecture includes additional exogenous\n",
    "    blocks, extending NBEATS capabilities and interpretability. With its interpretable\n",
    "    version, NBEATSx decomposes its predictions on seasonality, trend, and exogenous effects.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `h`: int, Forecast horizon. <br>\n",
    "    `input_size`: int, autorregresive inputs size, y=[1,2,3,4] input_size=2 -> y_[t-2:t]=[1,2].<br>\n",
    "    `futr_exog_list`: str list, future exogenous columns.<br>\n",
    "    `hist_exog_list`: str list, historic exogenous columns.<br>\n",
    "    `stat_exog_list`: str list, static exogenous columns.<br>\n",
    "    `exclude_insample_y`: bool=False, the model skips the autoregressive features y[t-input_size:t] if True.<br>\n",
    "    `n_harmonics`: int, Number of harmonic oscillations in the SeasonalityBasis [cos(i * t/n_harmonics), sin(i * t/n_harmonics)]. Note that it will only be used if 'seasonality' is in `stack_types`.<br>\n",
    "    `n_polynomials`: int, Number of polynomial terms for TrendBasis [1,t,...,t^n_poly]. Note that it will only be used if 'trend' is in `stack_types`.<br>\n",
    "    `stack_types`: List[str], List of stack types. Subset from ['seasonality', 'trend', 'identity', 'exogenous'].<br>\n",
    "    `n_blocks`: List[int], Number of blocks for each stack. Note that len(n_blocks) = len(stack_types).<br>\n",
    "    `mlp_units`: List[List[int]], Structure of hidden layers for each stack type. Each internal list should contain the number of units of each hidden layer. Note that len(n_hidden) = len(stack_types).<br>\n",
    "    `dropout_prob_theta`: float, Float between (0, 1). Dropout for N-BEATS basis.<br>\n",
    "    `activation`: str, activation from ['ReLU', 'Softplus', 'Tanh', 'SELU', 'LeakyReLU', 'PReLU', 'Sigmoid'].<br>\n",
    "    `loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n",
    "    `valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n",
    "    `max_steps`: int=1000, maximum number of training steps.<br>\n",
    "    `learning_rate`: float=1e-3, Learning rate between (0, 1).<br>\n",
    "    `num_lr_decays`: int=3, Number of learning rate decays, evenly distributed across max_steps.<br>\n",
    "    `early_stop_patience_steps`: int=-1, Number of validation iterations before early stopping.<br>\n",
    "    `val_check_steps`: int=100, Number of training steps between every validation loss check.<br>\n",
    "    `batch_size`: int=32, number of different series in each batch.<br>\n",
    "    `valid_batch_size`: int=None, number of different series in each validation and test batch, if None uses batch_size.<br>\n",
    "    `windows_batch_size`: int=1024, number of windows to sample in each training batch, default uses all.<br>\n",
    "    `inference_windows_batch_size`: int=-1, number of windows to sample in each inference batch, -1 uses all.<br>\n",
    "    `start_padding_enabled`: bool=False, if True, the model will pad the time series with zeros at the beginning, by input size.<br>\n",
    "    `step_size`: int=1, step size between each window of temporal data.<br>\n",
    "    `scaler_type`: str='identity', type of scaler for temporal inputs normalization see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).<br>\n",
    "    `random_seed`: int, random seed initialization for replicability.<br>\n",
    "    `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.<br>\n",
    "    `alias`: str, optional,  Custom name of the model.<br>\n",
    "    `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).<br>\n",
    "    `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>\n",
    "    `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).<br>\n",
    "    `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.<br>\n",
    "    `dataloader_kwargs`: dict, optional, list of parameters passed into the PyTorch Lightning dataloader by the `TimeSeriesDataLoader`. <br>\n",
    "    `**trainer_kwargs`: int,  keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>\n",
    "\n",
    "    **References:**<br>\n",
    "    -[Kin G. Olivares, Cristian Challu, Grzegorz Marcjasz, Rafał Weron, Artur Dubrawski (2021).\n",
    "    \"Neural basis expansion analysis with exogenous variables: Forecasting electricity prices with NBEATSx\".](https://arxiv.org/abs/2104.05522)\n",
    "    \"\"\"\n",
    "\n",
    "    # Class attributes\n",
    "    EXOGENOUS_FUTR = True\n",
    "    EXOGENOUS_HIST = True\n",
    "    EXOGENOUS_STAT = True\n",
    "    MULTIVARIATE = False    # If the model produces multivariate forecasts (True) or univariate (False)\n",
    "    RECURRENT = False       # If the model produces forecasts recursively (True) or direct (False)\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        h,\n",
    "        input_size,\n",
    "        futr_exog_list=None,\n",
    "        hist_exog_list=None,\n",
    "        stat_exog_list=None,\n",
    "        exclude_insample_y=False,\n",
    "        n_harmonics=2,\n",
    "        n_polynomials=2,\n",
    "        stack_types: list = [\"identity\", \"trend\", \"seasonality\"],\n",
    "        n_blocks: list = [1, 1, 1],\n",
    "        mlp_units: list = 3 * [[512, 512]],\n",
    "        dropout_prob_theta=0.0,\n",
    "        activation=\"ReLU\",\n",
    "        shared_weights=False,\n",
    "        loss=MAE(),\n",
    "        valid_loss=None,\n",
    "        max_steps: int = 1000,\n",
    "        learning_rate: float = 1e-3,\n",
    "        num_lr_decays: int = 3,\n",
    "        early_stop_patience_steps: int = -1,\n",
    "        val_check_steps: int = 100,\n",
    "        batch_size=32,\n",
    "        valid_batch_size: Optional[int] = None,\n",
    "        windows_batch_size: int = 1024,\n",
    "        inference_windows_batch_size: int = -1,\n",
    "        start_padding_enabled: bool = False,\n",
    "        step_size: int = 1,\n",
    "        scaler_type: str = \"identity\",\n",
    "        random_seed: int = 1,\n",
    "        drop_last_loader: bool = False,\n",
    "        alias: Optional[str] = None,\n",
    "        optimizer = None,\n",
    "        optimizer_kwargs = None,\n",
    "        lr_scheduler = None,\n",
    "        lr_scheduler_kwargs = None,\n",
    "        dataloader_kwargs = None,\n",
    "        **trainer_kwargs,\n",
    "    ):\n",
    "        # Protect horizon collapsed seasonality and trend NBEATSx-i basis\n",
    "        if h == 1 and ((\"seasonality\" in stack_types) or (\"trend\" in stack_types)):\n",
    "            raise Exception(\n",
    "                \"Horizon `h=1` incompatible with `seasonality` or `trend` in stacks\"\n",
    "            )\n",
    "\n",
    "        # Inherit BaseWindows class\n",
    "        super(NBEATSx, self).__init__(h=h, \n",
    "                                      input_size = input_size,\n",
    "                                      futr_exog_list=futr_exog_list,\n",
    "                                      hist_exog_list=hist_exog_list,\n",
    "                                      stat_exog_list=stat_exog_list,\n",
    "                                      exclude_insample_y=exclude_insample_y,                                      \n",
    "                                      loss=loss,\n",
    "                                      valid_loss=valid_loss,\n",
    "                                      max_steps=max_steps,\n",
    "                                      learning_rate=learning_rate,\n",
    "                                      num_lr_decays=num_lr_decays,\n",
    "                                      early_stop_patience_steps=early_stop_patience_steps,\n",
    "                                      val_check_steps=val_check_steps,\n",
    "                                      batch_size=batch_size,\n",
    "                                      valid_batch_size=valid_batch_size,\n",
    "                                      windows_batch_size = windows_batch_size,\n",
    "                                      inference_windows_batch_size=inference_windows_batch_size,\n",
    "                                      start_padding_enabled=start_padding_enabled,\n",
    "                                      step_size = step_size,\n",
    "                                      scaler_type=scaler_type,\n",
    "                                      random_seed=random_seed,\n",
    "                                      drop_last_loader=drop_last_loader,\n",
    "                                      alias=alias,\n",
    "                                      optimizer=optimizer,\n",
    "                                      optimizer_kwargs=optimizer_kwargs,\n",
    "                                      lr_scheduler=lr_scheduler,\n",
    "                                      lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
    "                                      dataloader_kwargs=dataloader_kwargs,\n",
    "                                      **trainer_kwargs)\n",
    "\n",
    "        # Architecture\n",
    "        blocks = self.create_stack(\n",
    "            h=h,\n",
    "            input_size=input_size,\n",
    "            futr_input_size=self.futr_exog_size,\n",
    "            hist_input_size=self.hist_exog_size,\n",
    "            stat_input_size=self.stat_exog_size,\n",
    "            stack_types=stack_types,\n",
    "            n_blocks=n_blocks,\n",
    "            mlp_units=mlp_units,\n",
    "            dropout_prob_theta=dropout_prob_theta,\n",
    "            activation=activation,\n",
    "            shared_weights=shared_weights,\n",
    "            n_polynomials=n_polynomials,\n",
    "            n_harmonics=n_harmonics,\n",
    "        )\n",
    "        self.blocks = torch.nn.ModuleList(blocks)\n",
    "\n",
    "    def create_stack(\n",
    "        self,\n",
    "        h,\n",
    "        input_size,\n",
    "        stack_types,\n",
    "        n_blocks,\n",
    "        mlp_units,\n",
    "        dropout_prob_theta,\n",
    "        activation,\n",
    "        shared_weights,\n",
    "        n_polynomials,\n",
    "        n_harmonics,\n",
    "        futr_input_size,\n",
    "        hist_input_size,\n",
    "        stat_input_size,\n",
    "    ):\n",
    "        block_list = []\n",
    "        for i in range(len(stack_types)):\n",
    "            for block_id in range(n_blocks[i]):\n",
    "                # Shared weights\n",
    "                if shared_weights and block_id > 0:\n",
    "                    nbeats_block = block_list[-1]\n",
    "                else:\n",
    "                    if stack_types[i] == \"seasonality\":\n",
    "                        n_theta = (\n",
    "                            2\n",
    "                            * (self.loss.outputsize_multiplier + 1)\n",
    "                            * int(np.ceil(n_harmonics / 2 * h) - (n_harmonics - 1))\n",
    "                        )\n",
    "                        basis = SeasonalityBasis(\n",
    "                            harmonics=n_harmonics,\n",
    "                            backcast_size=input_size,\n",
    "                            forecast_size=h,\n",
    "                            out_features=self.loss.outputsize_multiplier,\n",
    "                        )\n",
    "\n",
    "                    elif stack_types[i] == \"trend\":\n",
    "                        n_theta = (self.loss.outputsize_multiplier + 1) * (\n",
    "                            n_polynomials + 1\n",
    "                        )\n",
    "                        basis = TrendBasis(\n",
    "                            degree_of_polynomial=n_polynomials,\n",
    "                            backcast_size=input_size,\n",
    "                            forecast_size=h,\n",
    "                            out_features=self.loss.outputsize_multiplier,\n",
    "                        )\n",
    "\n",
    "                    elif stack_types[i] == \"identity\":\n",
    "                        n_theta = input_size + self.loss.outputsize_multiplier * h\n",
    "                        basis = IdentityBasis(\n",
    "                            backcast_size=input_size,\n",
    "                            forecast_size=h,\n",
    "                            out_features=self.loss.outputsize_multiplier,\n",
    "                        )\n",
    "\n",
    "                    elif stack_types[i] == \"exogenous\":\n",
    "                        if futr_input_size + stat_input_size > 0:\n",
    "                            n_theta = 2*(\n",
    "                                futr_input_size + stat_input_size\n",
    "                            )\n",
    "                            basis = ExogenousBasis(forecast_size=h)\n",
    "\n",
    "                    else:\n",
    "                        raise ValueError(f\"Block type {stack_types[i]} not found!\")\n",
    "\n",
    "                    nbeats_block = NBEATSBlock(\n",
    "                        input_size=input_size,\n",
    "                        h=h,\n",
    "                        futr_input_size=futr_input_size,\n",
    "                        hist_input_size=hist_input_size,\n",
    "                        stat_input_size=stat_input_size,\n",
    "                        n_theta=n_theta,\n",
    "                        mlp_units=mlp_units,\n",
    "                        basis=basis,\n",
    "                        dropout_prob=dropout_prob_theta,\n",
    "                        activation=activation,\n",
    "                    )\n",
    "\n",
    "                # Select type of evaluation and apply it to all layers of block\n",
    "                block_list.append(nbeats_block)\n",
    "\n",
    "        return block_list\n",
    "\n",
    "    def forward(self, windows_batch):\n",
    "        # Parse windows_batch\n",
    "        insample_y = windows_batch[\"insample_y\"].squeeze(-1)\n",
    "        insample_mask = windows_batch[\"insample_mask\"].squeeze(-1)\n",
    "        futr_exog = windows_batch[\"futr_exog\"]\n",
    "        hist_exog = windows_batch[\"hist_exog\"]\n",
    "        stat_exog = windows_batch[\"stat_exog\"]\n",
    "\n",
    "        # NBEATSx' forward\n",
    "        residuals = insample_y.flip(dims=(-1,))  # backcast init\n",
    "        insample_mask = insample_mask.flip(dims=(-1,))\n",
    "\n",
    "        forecast = insample_y[:, -1:, None]  # Level with Naive1\n",
    "        block_forecasts = [forecast.repeat(1, self.h, 1)]\n",
    "        for i, block in enumerate(self.blocks):\n",
    "            backcast, block_forecast = block(\n",
    "                insample_y=residuals,\n",
    "                futr_exog=futr_exog,\n",
    "                hist_exog=hist_exog,\n",
    "                stat_exog=stat_exog,\n",
    "            )\n",
    "            residuals = (residuals - backcast) * insample_mask\n",
    "            forecast = forecast + block_forecast\n",
    "\n",
    "            if self.decompose_forecast:\n",
    "                block_forecasts.append(block_forecast)\n",
    "\n",
    "        if self.decompose_forecast:\n",
    "            # (n_batch, n_blocks, h)\n",
    "            block_forecasts = torch.stack(block_forecasts)\n",
    "            block_forecasts = block_forecasts.permute(1, 0, 2, 3)\n",
    "            block_forecasts = block_forecasts.squeeze(-1)  # univariate output\n",
    "            return block_forecasts\n",
    "        else:\n",
    "            return forecast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c57a831f-94bc-4616-b579-c114c3fc57c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(NBEATSx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9013b63-f65b-4a92-913c-b696e6e69914",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(NBEATSx.fit, name='NBEATSx.fit')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a66184ee-7a71-4598-976c-c79b83089a6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(NBEATSx.predict, name='NBEATSx.predict')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce8cba7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# Unit tests for models\n",
    "logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR)\n",
    "logging.getLogger(\"lightning_fabric\").setLevel(logging.ERROR)\n",
    "with warnings.catch_warnings():\n",
    "    warnings.simplefilter(\"ignore\")\n",
    "    check_model(NBEATSx, [\"airpassengers\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "360618d9-1049-4172-af97-aff19335f78b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from neuralforecast.losses.pytorch import MQLoss\n",
    "from neuralforecast.tsdataset import TimeSeriesDataset\n",
    "from neuralforecast.utils import AirPassengersDF as Y_df\n",
    "from neuralforecast.utils import AirPassengersStatic as Y_static"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fa496f4-d7f1-4fdc-87d6-ea497f853cc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# Month\n",
    "Y_df['month'] = Y_df['ds'].dt.month\n",
    "Y_df['year'] = Y_df['ds'].dt.year\n",
    "\n",
    "Y_train_df = Y_df[Y_df.ds<Y_df['ds'].values[-12]] # 132 train\n",
    "Y_test_df = Y_df[Y_df.ds>=Y_df['ds'].values[-12]]   # 12 test\n",
    "\n",
    "dataset, *_ = TimeSeriesDataset.from_df(df = Y_train_df)\n",
    "model = NBEATSx(h=12,\n",
    "                input_size=24,\n",
    "                scaler_type='robust',\n",
    "                stack_types = [\"identity\", \"trend\", \"seasonality\", \"exogenous\"],\n",
    "                n_blocks = [1,1,1,1],\n",
    "                futr_exog_list=['month','year'],\n",
    "                windows_batch_size=None,\n",
    "                max_steps=1)\n",
    "model.fit(dataset=dataset)\n",
    "dataset2 = dataset.update_dataset(dataset, Y_test_df)\n",
    "model.set_test_size(12)\n",
    "y_hat = model.predict(dataset=dataset2)\n",
    "Y_test_df['NBEATSx'] = y_hat\n",
    "\n",
    "pd.concat([Y_train_df, Y_test_df]).drop(['unique_id','month'], axis=1).set_index('ds').plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db94b63e-d82c-423f-8f75-184ae285904d",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "#test we recover the same forecast\n",
    "y_hat2 = model.predict(dataset=dataset2)\n",
    "test_eq(y_hat, y_hat2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46090447-8e67-4f08-8a3d-9547183983f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "#test no leakage with test_size\n",
    "dataset, *_ = TimeSeriesDataset.from_df(Y_df)\n",
    "model = NBEATSx(h=12,\n",
    "                input_size=24,\n",
    "                scaler_type='robust',\n",
    "                stack_types = [\"identity\", \"trend\", \"seasonality\", \"exogenous\"],\n",
    "                n_blocks = [1,1,1,1],\n",
    "                futr_exog_list=['month','year'],\n",
    "                windows_batch_size=None,\n",
    "                max_steps=1)\n",
    "model.fit(dataset=dataset, test_size=12)\n",
    "y_hat_test = model.predict(dataset=dataset, step_size=1)\n",
    "np.testing.assert_almost_equal(y_hat, y_hat_test, decimal=4)\n",
    "#test we recover the same forecast\n",
    "y_hat_test2 = model.predict(dataset=dataset, step_size=1)\n",
    "test_eq(y_hat_test, y_hat_test2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02d7648e",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# test seasonality/trend basis protection\n",
    "test_fail(NBEATSx.__init__, \n",
    "          contains='Horizon `h=1` incompatible with `seasonality` or `trend` in stacks',\n",
    "          kwargs=dict(self=BaseModel, h=1, input_size=4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fde23c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# test inference_windows_batch_size\n",
    "dataset, *_ = TimeSeriesDataset.from_df(Y_df)\n",
    "model = NBEATSx(h=12,\n",
    "                input_size=24,\n",
    "                scaler_type='robust',\n",
    "                stack_types = [\"identity\", \"trend\", \"seasonality\", \"exogenous\"],\n",
    "                n_blocks = [1,1,1,1],\n",
    "                futr_exog_list=['month','year'],\n",
    "                windows_batch_size=None,\n",
    "                inference_windows_batch_size=1,\n",
    "                max_steps=1)\n",
    "model.fit(dataset=dataset, test_size=12)\n",
    "y_hat_test = model.predict(dataset=dataset, step_size=1)\n",
    "#test we recover the same forecast with different inference_windows_batch_size\n",
    "model.inference_windows_batch_size=-1\n",
    "y_hat_test2 = model.predict(dataset=dataset, step_size=1)\n",
    "test_eq(y_hat_test, y_hat_test2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8501c70",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# Check val_check_steps protection to less than max_steps\n",
    "dataset, *_ = TimeSeriesDataset.from_df(Y_df)\n",
    "model = NBEATSx(h=12,\n",
    "                input_size=24,\n",
    "                scaler_type='robust',\n",
    "                stack_types = [\"identity\", \"trend\", \"seasonality\", \"exogenous\"],\n",
    "                n_blocks = [1,1,1,1],\n",
    "                futr_exog_list=['month','year'],\n",
    "                windows_batch_size=None,\n",
    "                early_stop_patience_steps=1,\n",
    "                max_steps=1,\n",
    "                val_check_steps=5\n",
    "                )\n",
    "model.fit(dataset=dataset, test_size=12, val_size=12)\n",
    "test_eq(model.trainer_kwargs['val_check_interval'], 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da9b51bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# test using the ExogenousBasis with both static and future exogenous variables\n",
    "dataset, *_ = TimeSeriesDataset.from_df(df = Y_train_df, static_df=Y_static)\n",
    "model = NBEATSx(h=12,\n",
    "                input_size=24,\n",
    "                scaler_type='robust',\n",
    "                stack_types = [\"seasonality\", \"exogenous\"],\n",
    "                n_blocks = [1,1],\n",
    "                futr_exog_list=['month','year'],\n",
    "                stat_exog_list=['airline1', 'airline2'],\n",
    "                windows_batch_size=None,\n",
    "                max_steps=1)\n",
    "model.fit(dataset=dataset)\n",
    "dataset2 = dataset.update_dataset(dataset, Y_test_df)\n",
    "model.set_test_size(12)\n",
    "y_hat = model.predict(dataset=dataset2)\n",
    "assert(len(y_hat)==12)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01e8ae90",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "Y_train_df = Y_df[Y_df.ds<Y_df['ds'].values[-12]] # 132 train\n",
    "Y_test_df = Y_df[Y_df.ds>=Y_df['ds'].values[-12]]   # 12 test\n",
    "\n",
    "# Fit MQ-MLP\n",
    "dataset, *_ = TimeSeriesDataset.from_df(Y_train_df)\n",
    "model = NBEATSx(h=12, input_size=24, max_steps=1,\n",
    "                scaler_type='robust',\n",
    "                stack_types = [\"identity\", \"trend\", \"seasonality\", \"exogenous\"],\n",
    "                n_blocks = [1,1,1,1],\n",
    "                futr_exog_list=['month','year'],\n",
    "                loss=MQLoss(level=[80, 90]))\n",
    "model.fit(dataset=dataset, val_size=12)\n",
    "\n",
    "# Parse quantile predictions\n",
    "dataset2 = dataset.update_dataset(dataset, Y_test_df)\n",
    "model.set_test_size(12)\n",
    "y_hat = model.predict(dataset=dataset2)\n",
    "Y_hat_df = pd.DataFrame.from_records(data=y_hat,\n",
    "                columns=['NBEATS'+q for q in model.loss.output_names],\n",
    "                index=Y_test_df.index)\n",
    "\n",
    "# Plot quantile predictions\n",
    "plot_df = pd.concat([Y_test_df, Y_hat_df], axis=1)\n",
    "plot_df = pd.concat([Y_train_df, plot_df]).drop('unique_id', axis=1)\n",
    "plt.plot(plot_df['ds'], plot_df['y'], c='black', label='True')\n",
    "plt.plot(plot_df['ds'], plot_df['NBEATS-median'], c='blue', label='median')\n",
    "plt.fill_between(x=plot_df['ds'][-12:], \n",
    "                 y1=plot_df['NBEATS-lo-90'][-12:].values, \n",
    "                 y2=plot_df['NBEATS-hi-90'][-12:].values,\n",
    "                 alpha=0.4, label='level 90')\n",
    "plt.grid()\n",
    "plt.legend()\n",
    "plt.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0298fce5-eb13-40dc-9964-b026fd2a8928",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# test validation step\n",
    "dataset, *_ = TimeSeriesDataset.from_df(Y_train_df)\n",
    "model = NBEATSx(h=12, input_size=24, \n",
    "                windows_batch_size=None, max_steps=1, \n",
    "                scaler_type='robust',\n",
    "                stack_types = [\"identity\", \"trend\", \"seasonality\", \"exogenous\"],\n",
    "                n_blocks = [1,1,1,1],\n",
    "                futr_exog_list=['month','year'])\n",
    "model.fit(dataset=dataset, val_size=12)\n",
    "dataset2 = dataset.update_dataset(dataset, Y_test_df)\n",
    "model.set_test_size(12)\n",
    "y_hat_w_val = model.predict(dataset=dataset2)\n",
    "Y_test_df['N-BEATS'] = y_hat_w_val\n",
    "\n",
    "pd.concat([Y_train_df, Y_test_df]).drop('unique_id', axis=1).set_index('ds').plot()\n",
    "plt.grid()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f987ed0-ee6e-4f66-bd8f-96acc6fbd56c",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# test no leakage with test_size and val_size\n",
    "dataset, *_ = TimeSeriesDataset.from_df(Y_train_df)\n",
    "model = NBEATSx(h=12, input_size=24, windows_batch_size=None, max_steps=1,\n",
    "                scaler_type='robust',stack_types = [\"identity\", \"trend\", \"seasonality\", \"exogenous\"],n_blocks = [1,1,1,1],futr_exog_list=['month','year'])\n",
    "model.fit(dataset=dataset, val_size=12)\n",
    "dataset2 = dataset.update_dataset(dataset, Y_test_df)\n",
    "model.set_test_size(12)\n",
    "y_hat_w_val = model.predict(dataset=dataset2)\n",
    "\n",
    "dataset, *_ = TimeSeriesDataset.from_df(Y_df)\n",
    "model = NBEATSx(input_size=24, h=12, windows_batch_size=None, max_steps=1,\n",
    "                scaler_type='robust',stack_types = [\"identity\", \"trend\", \"seasonality\", \"exogenous\"],n_blocks = [1,1,1,1], futr_exog_list=['month','year'])\n",
    "model.fit(dataset=dataset, val_size=12, test_size=12)\n",
    "\n",
    "y_hat_test_w_val = model.predict(dataset=dataset, step_size=1)\n",
    "\n",
    "np.testing.assert_almost_equal(y_hat_test_w_val, y_hat_w_val, decimal=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "036be4bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# qualitative decomposition evaluation\n",
    "y_hat = model.decompose(dataset=dataset)\n",
    "\n",
    "fig, ax = plt.subplots(6, 1, figsize=(10, 15))\n",
    "\n",
    "ax[0].plot(Y_test_df['y'].values, label='True', color=\"#9C9DB2\", linewidth=4)\n",
    "ax[0].plot(y_hat.sum(axis=1).flatten(), label='Forecast', color=\"#7B3841\")\n",
    "ax[0].grid()\n",
    "ax[0].legend(prop={'size': 20})\n",
    "for label in (ax[0].get_xticklabels() + ax[0].get_yticklabels()):\n",
    "    label.set_fontsize(18)\n",
    "ax[0].set_ylabel('y', fontsize=20)\n",
    "\n",
    "ax[1].plot(y_hat[0,0], label='level', color=\"#7B3841\")\n",
    "ax[1].grid()\n",
    "ax[1].set_ylabel('Level', fontsize=20)\n",
    "\n",
    "ax[2].plot(y_hat[0,1], label='stack1', color=\"#7B3841\")\n",
    "ax[2].grid()\n",
    "ax[2].set_ylabel('Identity', fontsize=20)\n",
    "\n",
    "ax[3].plot(y_hat[0,2], label='stack2', color=\"#D9AE9E\")\n",
    "ax[3].grid()\n",
    "ax[3].set_ylabel('Trend', fontsize=20)\n",
    "\n",
    "ax[4].plot(y_hat[0,3], label='stack3', color=\"#D9AE9E\")\n",
    "ax[4].grid()\n",
    "ax[4].set_ylabel('Seasonality', fontsize=20)\n",
    "\n",
    "ax[5].plot(y_hat[0,4], label='stack4', color=\"#D9AE9E\")\n",
    "ax[5].grid()\n",
    "ax[5].set_ylabel('Exogenous', fontsize=20)\n",
    "\n",
    "ax[5].set_xlabel('Prediction \\u03C4 \\u2208 {t+1,..., t+H}', fontsize=20)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "358dd690",
   "metadata": {},
   "source": [
    "## Usage Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68681f2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| eval: false\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from neuralforecast import NeuralForecast\n",
    "from neuralforecast.models import NBEATSx\n",
    "from neuralforecast.losses.pytorch import MQLoss\n",
    "from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic\n",
    "\n",
    "Y_train_df = AirPassengersPanel[AirPassengersPanel.ds<AirPassengersPanel['ds'].values[-12]] # 132 train\n",
    "Y_test_df = AirPassengersPanel[AirPassengersPanel.ds>=AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 12 test\n",
    "\n",
    "model = NBEATSx(h=12, input_size=24,\n",
    "                loss=MQLoss(level=[80, 90]),\n",
    "                scaler_type='robust',\n",
    "                dropout_prob_theta=0.5,\n",
    "                stat_exog_list=['airline1'],\n",
    "                futr_exog_list=['trend'],\n",
    "                stack_types = [\"identity\", \"trend\", \"seasonality\", \"exogenous\"],\n",
    "                n_blocks = [1,1,1,1],\n",
    "                max_steps=200,\n",
    "                val_check_steps=10,\n",
    "                early_stop_patience_steps=2)\n",
    "\n",
    "nf = NeuralForecast(\n",
    "    models=[model],\n",
    "    freq='ME'\n",
    ")\n",
    "nf.fit(df=Y_train_df, static_df=AirPassengersStatic, val_size=12)\n",
    "Y_hat_df = nf.predict(futr_df=Y_test_df)\n",
    "\n",
    "# Plot quantile predictions\n",
    "Y_hat_df = Y_hat_df.reset_index(drop=False).drop(columns=['unique_id','ds'])\n",
    "plot_df = pd.concat([Y_test_df, Y_hat_df], axis=1)\n",
    "plot_df = pd.concat([Y_train_df, plot_df])\n",
    "\n",
    "plot_df = plot_df[plot_df.unique_id=='Airline1'].drop('unique_id', axis=1)\n",
    "plt.plot(plot_df['ds'], plot_df['y'], c='black', label='True')\n",
    "plt.plot(plot_df['ds'], plot_df['NBEATSx-median'], c='blue', label='median')\n",
    "plt.fill_between(x=plot_df['ds'][-12:], \n",
    "                 y1=plot_df['NBEATSx-lo-90'][-12:].values, \n",
    "                 y2=plot_df['NBEATSx-hi-90'][-12:].values,\n",
    "                 alpha=0.4, label='level 90')\n",
    "plt.legend()\n",
    "plt.grid()\n",
    "plt.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf05ff9a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
